19d8ca
@@ -19,6 +19,7 @@
 package org.apache.hadoop.hive.ql.udf.generic;
 
 import java.util.ArrayList;
+import java.util.List;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -55,29 +56,49 @@ protected GenericUDAFRankEvaluator createEvaluator()
 		return new GenericUDAFCumeDistEvaluator();
 	}
 
-	public static class GenericUDAFCumeDistEvaluator extends GenericUDAFRankEvaluator
-	{
-		@Override
-		public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException
-		{
-			super.init(m, parameters);
-			return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
-		}
-
-		@Override
-		public Object terminate(AggregationBuffer agg) throws HiveException
-		{
-			ArrayList<IntWritable> ranks =  ((RankBuffer) agg).rowNums;
-			double sz = ranks.size();
-			ArrayList<DoubleWritable> pranks = new ArrayList<DoubleWritable>(ranks.size());
+  public static class GenericUDAFCumeDistEvaluator extends GenericUDAFRankEvaluator
+  {
+    @Override
+    public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException
+    {
+      super.init(m, parameters);
+      ObjectInspector elementOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
+      return ObjectInspectorFactory.getStandardListObjectInspector(elementOI); // list of doubles
+    }
 
-			for(IntWritable i : ranks)
-			{
-				double pr = ((double)i.get())/sz;
-				pranks.add(new DoubleWritable(pr));
-			}
-
-			return pranks;
-		}
-	}
+    /**
+     * Turns the ranks gathered in the buffer into cumulative-distribution values.
+     * Each maximal run of adjacent rows with the same rank forms one group; every row in a
+     * group gets (number of rows up to and including the end of that group) / (total rows),
+     * so the final group always maps to 1.0 and an empty buffer yields an empty list.
+     *
+     * @param agg the {@link RankBuffer} whose {@code rowNums} supply the rank values
+     * @return a list with one {@link DoubleWritable} per entry of {@code rowNums}, in order
+     */
+    @Override
+    public Object terminate(AggregationBuffer agg) throws HiveException
+    {
+      List<IntWritable> ranks = ((RankBuffer) agg).rowNums;
+      int ranksSize = ranks.size();
+      double ranksSizeDouble = ranksSize;
+      List<DoubleWritable> distances = new ArrayList<DoubleWritable>(ranksSize);
+      // Tied rows are expected to be adjacent; only consecutive equal ranks are grouped.
+      int groupStart = 0;
+      while (groupStart < ranksSize) {
+        int rank = ranks.get(groupStart).get();
+        int groupEnd = groupStart + 1;
+        while (groupEnd < ranksSize && ranks.get(groupEnd).get() == rank) {
+          groupEnd++;
+        }
+        // groupEnd is now the count of rows at or before the end of this group
+        double distance = ((double) groupEnd) / ranksSizeDouble;
+        for (int i = groupStart; i < groupEnd; i++) {
+          distances.add(new DoubleWritable(distance));
+        }
+        // the next group begins where this one ended
+        groupStart = groupEnd;
+      }
+      return distances;
+    }
+  }
 }
